# Copyright 2019 Google LLC (original)
# Copyright 2019 Uizard Technologies (small modifications)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Training loop, checkpoint saving and loading, evaluation code."""
from npu_bridge.npu_init import *
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

import json
import os.path
import shutil

from absl import flags
from easydict import EasyDict
from tqdm import trange
import time
from libml import data, utils
import numpy as np
import tensorflow as tf
# from tensorflow.keras.utils import np_utils
from sklearn.metrics import f1_score, recall_score, precision_score
#import moxing as mox


# Command-line flags consumed by the models in this module.
# NOTE(review): `FLAGS.nclass` and `FLAGS.dataset` are read further down but are
# not defined here -- presumably registered by the launching script; verify.
FLAGS = flags.FLAGS
flags.DEFINE_string('train_dir', './experiments',
                    'Folder where to save training data.')
flags.DEFINE_float('lr', 0.0001, 'Learning rate.')
# Also the per-step increment of the global step (Model adds FLAGS.batch to it),
# so the global step counts samples rather than batches.
flags.DEFINE_integer('batch', 64, 'Batch size.')
flags.DEFINE_integer('train_kimg', 1 << 14, 'Training duration in kibi-samples.')
flags.DEFINE_integer('report_kimg', 64, 'Report summary period in kibi-samples.')
# Converted to samples with `<< 10` when passed as save_checkpoint_steps.
flags.DEFINE_integer('save_kimg', 64, 'Save checkpoint period in kibi-samples.')
# Used as `max_to_keep` for the training and fine-tuning savers.
flags.DEFINE_integer('keep_ckpt', 50, 'Number of checkpoints to keep.')
# NOTE(review): ClassifySemi.eval_checkpoint currently restores the latest
# checkpoint in its checkpoint_dir rather than this path; a non-empty value
# only switches train() into eval-only mode.
flags.DEFINE_string('eval_ckpt', '', 'Checkpoint to evaluate. If provided, do not do training, just do eval.')
# Training stops once 1 + global_step // report_nimg exceeds this value.
flags.DEFINE_integer('epochs', 700, 'Number of epochs.')
# Number of samples subtracted from train_nimg in the training-loop bound
# (training stops once global_step >= train_nimg - epochs_ctrl).
flags.DEFINE_integer('epochs_ctrl', 0, 'Number of control.')
# Maximum number of train_step calls per pass over the per-epoch progress bar.
flags.DEFINE_integer('steps_ctrl', 1024, 'Number of control.')
flags.DEFINE_string('class_mapping', '', 'Name of file containing class mappings generated by scripts/create_datasets.py')
# When set, every global variable whose name contains neither "dense" nor
# "global_step" is restored from this checkpoint at session creation.
flags.DEFINE_string(
    "load_checkpoint",
    "",
    "Checkpoint file to start training from (e.g. "
    ".../model.ckpt-354615), or None for random init",)


class Model:
    """Base class tying a TF1 graph to an experiment directory and checkpoints.

    Subclasses implement ``model(**kwargs)``, which must return an object that
    supports attribute assignment (e.g. an ``EasyDict`` of ops; this class adds
    ``update_step`` to it), and ``add_summaries(**kwargs)``.
    """

    def __init__(self, train_dir: str, dataset: data.DataSet, **kwargs):
        # Every hyper-parameter is folded into the directory name so runs with
        # different settings land in different folders.
        self.train_dir = os.path.join(train_dir, self.experiment_name(**kwargs))
        self.params = EasyDict(kwargs)
        self.dataset = dataset
        self.graph = tf.get_default_graph()
        self.session = None
        self.tmp = EasyDict(print_queue=[], cache=EasyDict())
        with self.graph.as_default():
            self.step = tf.train.get_or_create_global_step()
            self.ops = self.model(**kwargs)
            # The global step counts samples: it advances by FLAGS.batch per update.
            self.ops.update_step = tf.assign_add(self.step, FLAGS.batch)

        print(' Config '.center(80, '-'))
        print('train_dir', self.train_dir)
        print('%-32s %s' % ('Model', self.__class__.__name__))
        print('%-32s %s' % ('Dataset', dataset.name))
        for k, v in sorted(kwargs.items()):
            print('%-32s %s' % (k, v))
        print(' Model '.center(80, '-'))
        with self.graph.as_default():
            to_print = [tuple(['%s' % x for x in (v.name, np.prod(v.shape), v.shape)]) for v in utils.model_vars(None)]
        to_print.append(('Total', str(sum(int(x[1]) for x in to_print)), ''))
        # Column widths for a left/right-aligned table of (name, size, shape).
        sizes = [max([len(x[i]) for x in to_print]) for i in range(3)]
        fmt = '%%-%ds  %%%ds  %%%ds' % tuple(sizes)
        for x in to_print[:-1]:
            print(fmt % x)
        print()
        print(fmt % to_print[-1])
        print('-' * 80)
        self._create_initial_files()

    @property
    def arg_dir(self):
        """Directory holding ``args.json``."""
        return os.path.join(self.train_dir, 'args')

    @property
    def checkpoint_dir(self):
        """Directory holding TensorFlow checkpoints."""
        return os.path.join(self.train_dir, 'tf')

    def train_print(self, text):
        """Queue ``text`` to be printed by the training loop."""
        self.tmp.print_queue.append(text)

    def _create_initial_files(self):
        """Create the checkpoint/args directories (if needed) and write args.json."""
        for path in (self.checkpoint_dir, self.arg_dir):
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
            os.makedirs(path, exist_ok=True)
        self.save_args()

    def _reset_files(self):
        """Delete the whole experiment directory and recreate the initial files."""
        shutil.rmtree(self.train_dir)
        self._create_initial_files()

    def save_args(self, **extra_params):
        """Write the constructor kwargs (plus ``extra_params``) to args/args.json."""
        with open(os.path.join(self.arg_dir, 'args.json'), 'w') as f:
            json.dump({**self.params, **extra_params}, f, sort_keys=True, indent=4)

    @classmethod
    def load(cls, train_dir, dataset=None):
        """Rebuild a model from ``<train_dir>/args/args.json``.

        Args:
            train_dir: experiment directory containing ``args/args.json``.
            dataset: dataset to attach. ``__init__`` requires one and args.json
                does not store it, so it must be supplied here unless a subclass
                constructor provides its own default.

        Returns:
            The new instance, with ``train_dir`` pointing at ``train_dir``.
        """
        with open(os.path.join(train_dir, 'args/args.json'), 'r') as f:
            params = json.load(f)
        if dataset is not None:
            params['dataset'] = dataset
        instance = cls(train_dir=train_dir, **params)
        instance.train_dir = train_dir
        return instance

    def experiment_name(self, **kwargs):
        """Return ``<ClassName>_<key1><val1>_<key2><val2>...`` with keys sorted."""
        args = [x + str(y) for x, y in sorted(kwargs.items())]
        return '_'.join([self.__class__.__name__] + args)

    def eval_mode(self, ckpt=None):
        """Open a new NPU session and restore a checkpoint into it.

        Args:
            ckpt: checkpoint path; when None the latest checkpoint in
                ``checkpoint_dir`` is used.

        Returns:
            self, to allow chaining.
        """
        with self.graph.as_default():
            self.session = tf.Session(config=npu_config_proto(config_proto=utils.get_config()))
            saver = tf.train.Saver()
            if ckpt is None:
                ckpt = utils.find_latest_checkpoint(self.checkpoint_dir)
            else:
                ckpt = os.path.abspath(ckpt)
            saver.restore(self.session, ckpt)
            self.tmp.step = self.session.run(self.step)
        print('Eval model %s at global_step %d' % (self.__class__.__name__, self.tmp.step))
        return self

    def model(self, **kwargs):
        """Build the graph; must be implemented by subclasses."""
        raise NotImplementedError()

    def add_summaries(self, **kwargs):
        """Register summaries; must be implemented by subclasses."""
        raise NotImplementedError()


class ClassifySemi(Model):
    """Semi-supervised classification.

    Runs the training loop on an NPU-configured MonitoredTrainingSession and
    provides evaluation helpers (accuracy, per-class accuracy, F1/precision/
    recall, validation loss) over cached numpy copies of the eval datasets.
    """

    def __init__(self, train_dir: str, dataset: data.DataSet, nclass: int, **kwargs):
        self.nclass = nclass
        Model.__init__(self, train_dir, dataset, nclass=nclass, **kwargs)

    @staticmethod
    def _find_tensor(name):
        """Return the tensor ``name`` from the default graph, or None if it does not exist."""
        try:
            return tf.get_default_graph().get_tensor_by_name(name)
        except KeyError:
            return None

    def train_step(self, train_session, data_labeled, data_unlabeled):
        """Run exactly one optimisation step and store the new global step.

        Args:
            train_session: session used to run the training ops.
            data_labeled: ``get_next()`` output yielding the labeled batch.
            data_unlabeled: ``get_next()`` output yielding the unlabeled batch.
        """
        start_time = time.time()
        x, y = self.session.run([data_labeled, data_unlabeled])
        feed_dict = {self.ops.x: x['image'],
                     self.ops.y: y['image'],
                     self.ops.label: x['label']}
        # The NPU loss-scale diagnostics are fetched in the same run as train_op.
        # Running train_op a second time just to read them would apply two
        # optimizer updates per batch and advance the global step by 2 * batch.
        fetches = {'train_op': self.ops.train_op, 'step': self.ops.update_step}
        loss_scale = self._find_tensor('loss_scale:0')
        overflow = self._find_tensor('overflow_status_reduce_all:0')
        if loss_scale is not None:
            fetches['loss_scale'] = loss_scale
        if overflow is not None:
            fetches['overflow'] = overflow
        results = train_session.run(fetches, feed_dict=feed_dict)
        self.tmp.step = results['step']

        perf = time.time() - start_time
        fps = len(x['image']) / perf  # actual batch size rather than a hard-coded 64
        print('perf: {:.2f} fps {:.2f}'.format(perf, fps))
        if 'loss_scale' in results:
            print('loss_scale is: ', results['loss_scale'])
        if 'overflow' in results:
            print("overflow_status_reduce_all:", results['overflow'])
        print("global_step:", self.tmp.step)

    def train(self, train_nimg, report_nimg, batch):
        """Train until ``train_nimg - FLAGS.epochs_ctrl`` samples or FLAGS.epochs epochs.

        When FLAGS.eval_ckpt is set, only ``eval_checkpoint`` is run.

        Args:
            train_nimg: total number of training samples.
            report_nimg: samples per "epoch" (progress bar / epoch counter).
            batch: batch size used for both labeled and unlabeled iterators.

        Raises:
            ValueError: if FLAGS.steps_ctrl < 1 (the loop could never advance).
        """
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        if FLAGS.steps_ctrl < 1:
            raise ValueError('--steps_ctrl must be >= 1, got %d' % FLAGS.steps_ctrl)
        with self.graph.as_default():
            train_labeled = self.dataset.train_labeled.batch(batch, drop_remainder=True).prefetch(16)
            train_labeled = train_labeled.make_one_shot_iterator().get_next()
            train_unlabeled = self.dataset.train_unlabeled.batch(batch, drop_remainder=True).prefetch(16)
            train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()

            saver = tf.train.Saver(max_to_keep=FLAGS.keep_ckpt, pad_step_number=10)
            if FLAGS.load_checkpoint:
                # Fine-tuning restores everything except variables whose name contains
                # "dense" (presumably the classifier head -- confirm) and the global step.
                vars_to_load = [v for v in tf.global_variables()
                                if "dense" not in v.name and "global_step" not in v.name]
                print([v.name for v in vars_to_load])
                finetuning_saver = tf.train.Saver(var_list=vars_to_load, max_to_keep=FLAGS.keep_ckpt,
                                                  pad_step_number=10)

            def init_fn(_, sess):
                if FLAGS.load_checkpoint:
                    tf.logging.info("Fine tuning from checkpoint: %s", FLAGS.load_checkpoint)
                    finetuning_saver.restore(sess, FLAGS.load_checkpoint)

            scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)

            # NPU session configuration.
            config_proto = tf.ConfigProto()
            custom_op = config_proto.graph_options.rewrite_options.custom_optimizers.add()
            custom_op.name = "NpuOptimizer"
            custom_op.parameter_map["use_off_line"].b = True  # train on the Ascend AI processor
            # Disable the fusion rules listed in this file.
            custom_op.parameter_map["fusion_switch_file"].s = tf.compat.as_bytes("fusion_switch.cfg")
            # Optional debug switches (currently off):
            #   precision_mode = "allow_mix_precision"
            #   dump_path = <pre-created dir writable by the run user>,
            #   enable_dump_debug = True (overflow detection),
            #   dump_debug_mode in {all, aicore_overflow, atomic_overflow}
            #   modify_mixlist = <path to ops_info.json> (mixed-precision block list)
            config = npu_config_proto(config_proto=config_proto)

            with tf.train.MonitoredTrainingSession(
                    scaffold=scaffold,
                    checkpoint_dir=self.checkpoint_dir,
                    config=config,
                    save_checkpoint_steps=FLAGS.save_kimg << 10,
                    save_summaries_steps=report_nimg - batch) as train_session:
                print("Training...")
                self.session = train_session._tf_sess()
                self.tmp.step = self.session.run(self.step)

                if FLAGS.epochs < 1 + (self.tmp.step // report_nimg):
                    print("Training of " + str(FLAGS.epochs) + " epochs complete.")
                    return
                # self.tmp.step counts samples (update_step adds FLAGS.batch per step).
                while self.tmp.step < train_nimg - FLAGS.epochs_ctrl:
                    if FLAGS.epochs < 1 + (self.tmp.step // report_nimg):
                        print("Training of " + str(FLAGS.epochs) + " epochs complete.")
                        print("tranport ok!")
                        return

                    loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                                  leave=False, unit='img', unit_scale=batch,
                                  desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
                    # At most FLAGS.steps_ctrl steps per pass over the progress bar.
                    for q, _ in enumerate(loop, 1):
                        if q > FLAGS.steps_ctrl:
                            break
                        self.train_step(train_session, train_labeled, train_unlabeled)
                        while self.tmp.print_queue:
                            loop.write(self.tmp.print_queue.pop(0))
                while self.tmp.print_queue:
                    print(self.tmp.print_queue.pop(0))

    def tune(self, train_nimg):
        """Run ``self.ops.tune_op`` over ``train_nimg`` samples of training data."""
        batch = FLAGS.batch
        with self.graph.as_default():
            train_labeled = self.dataset.train_labeled.batch(batch, drop_remainder=True).prefetch(16)
            train_labeled = train_labeled.make_one_shot_iterator().get_next()
            train_unlabeled = self.dataset.train_unlabeled.batch(batch, drop_remainder=True).prefetch(16)
            train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()

            for _ in trange(0, train_nimg, batch, leave=False, unit='img', unit_scale=batch, desc='Tuning'):
                x, y = self.session.run([train_labeled, train_unlabeled])
                self.session.run([self.ops.tune_op], feed_dict={self.ops.x: x['image'],
                                                                self.ops.y: y['image'],
                                                                self.ops.label: x['label']})

    def eval_checkpoint(self, ckpt=None):
        """Print raw/EMA accuracies before and after ``tune``.

        NOTE(review): ``ckpt`` is currently ignored. ``eval_mode()`` is called
        without it (the ``eval_mode(ckpt)`` call was commented out in this port),
        so the latest checkpoint in ``checkpoint_dir`` is restored. Confirm intent.
        """
        self.eval_mode()
        raw = self.eval_stats(classify_op=self.ops.classify_raw)
        ema = self.eval_stats(classify_op=self.ops.classify_op)
        self.tune(16384)
        tuned_raw = self.eval_stats(classify_op=self.ops.classify_raw)
        tuned_ema = self.eval_stats(classify_op=self.ops.classify_op)
        print('%16s %8s %8s %8s' % ('', 'labeled', 'valid', 'test'))
        for title, values in (('raw', raw), ('ema', ema), ('tuned_raw', tuned_raw), ('tuned_ema', tuned_ema)):
            print('%16s %8s %8s %8s' % ((title,) + tuple('%.2f' % x for x in values)))

    def _cache_eval_samples(self, isolated_graph=True):
        """Materialise the eval datasets into numpy arrays once.

        Stores ``(images, labels)`` tuples in ``self.tmp.cache`` under 'test',
        'valid' and 'train_labeled' (the latter from ``dataset.eval_labeled``).
        Does nothing when all three are already cached.

        Args:
            isolated_graph: build the iterators inside a fresh ``tf.Graph`` (True)
                or in the current default graph (False).
        """
        if all(key in self.tmp.cache for key in ('test', 'valid', 'train_labeled')):
            return

        def collect_samples(sess, dataset):
            # batch(1) never leaves a remainder, so drop_remainder keeps every example.
            data_it = dataset.batch(1, drop_remainder=True).prefetch(16).make_one_shot_iterator().get_next()
            images, labels = [], []
            while True:
                try:
                    v = sess.run(data_it)
                except tf.errors.OutOfRangeError:
                    break
                images.append(v['image'])
                labels.append(v['label'])
            return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)

        def fill_cache():
            with tf.Session(config=npu_config_proto(config_proto=utils.get_config())) as sess:
                self.tmp.cache.test = collect_samples(sess, self.dataset.test)
                self.tmp.cache.valid = collect_samples(sess, self.dataset.valid)
                self.tmp.cache.train_labeled = collect_samples(sess, self.dataset.eval_labeled)

        if isolated_graph:
            with tf.Graph().as_default():
                fill_cache()
        else:
            fill_cache()

    def _predict_in_batches(self, classify_op, input_tensor, images, batch, feed_extra=None):
        """Run ``classify_op`` on ``images`` in chunks of ``batch`` and concatenate along axis 0."""
        return np.concatenate([
            self.session.run(classify_op, feed_dict={
                input_tensor: images[start:start + batch], **(feed_extra or {})})
            for start in range(0, images.shape[0], batch)
        ], axis=0)

    def eval_stats(self, batch=None, feed_extra=None, classify_op=None):
        """Accuracy (in percent) on the labeled-train, valid and test subsets.

        Also queues, via ``train_print``, the accuracy line and the maximum
        classifier output of the first (up to) 10 test samples.

        Args:
            batch: evaluation batch size; defaults to 50.
            feed_extra: extra feed_dict entries.
            classify_op: op to evaluate; defaults to ``self.ops.classify_op``.

        Returns:
            float32 array ``[train_labeled_acc, valid_acc, test_acc]``.
        """
        # Unlike the other eval_* helpers, samples are collected in the current default graph.
        self._cache_eval_samples(isolated_graph=False)
        batch = batch or 50
        classify_op = self.ops.classify_op if classify_op is None else classify_op
        stats = []
        ood_masks = []

        for subset in ('train_labeled', 'valid', 'test'):
            images, labels = self.tmp.cache[subset]
            predicted = self._predict_in_batches(classify_op, self.ops.x_pre, images, batch, feed_extra)
            stats.append((predicted.argmax(1) == labels).mean() * 100)
            if subset == 'test':
                # Slice instead of indexing 0..9 so small test sets do not raise IndexError.
                ood_masks = np.amax(predicted, axis=-1)[:10].tolist()

        self.train_print('kimg %-5d  accuracy train/valid/test  %.2f  %.2f  %.2f' %
                         tuple([self.tmp.step >> 10] + stats[:3]))
        self.train_print('ood' + ''.join(' %.4f' % value for value in ood_masks))
        return np.array(stats, 'f')

    def eval_class_stats(self, batch=None, feed_extra=None, classify_op=None):
        """Per-class test accuracy in percent (NaN for classes absent from the test set).

        Raises:
            ValueError: if a test label is not in ``range(FLAGS.nclass)``.
        """
        self._cache_eval_samples(isolated_graph=True)
        batch = batch or FLAGS.batch
        classify_op = self.ops.classify_op if classify_op is None else classify_op
        images, labels = self.tmp.cache['test']
        predicted = self._predict_in_batches(classify_op, self.ops.x, images, batch, feed_extra)

        nclass = FLAGS.nclass
        # Labels index per-class buckets 0..nclass-1, so nclass itself is out of range.
        if labels.max() >= nclass:
            raise ValueError("Please provide the correct number of class in --nclass")

        hits = predicted.argmax(1) == labels  # argmax computed once, not per sample
        mean_class_accuracies = []
        for cls_id in range(nclass):
            mask = labels == cls_id
            # Multiply the mean, not the list, so the result is a percentage
            # consistent with eval_stats.
            mean_class_accuracies.append(hits[mask].mean() * 100 if mask.any() else np.nan)
        return np.array(mean_class_accuracies, 'f')

    def eval_f1_pr(self, batch=None, feed_extra=None, classify_op=None):
        """Test-set F1/precision/recall.

        Returns:
            float32 array ``[f1_macro, precision_macro, recall_macro,
            f1_weighted, precision_weighted, recall_weighted]``.
        """
        self._cache_eval_samples(isolated_graph=True)
        batch = batch or FLAGS.batch
        classify_op = self.ops.classify_op if classify_op is None else classify_op
        images, labels = self.tmp.cache['test']
        predicted_classes = self._predict_in_batches(
            classify_op, self.ops.x, images, batch, feed_extra).argmax(1)

        scores = [metric(labels, predicted_classes, average=average)
                  for average in ('macro', 'weighted')
                  for metric in (f1_score, precision_score, recall_score)]
        return np.array(scores, 'f')

    def eval_val_loss(self, batch=None, feed_extra=None, classify_op=None):
        """Evaluate ``self.ops.eval_loss_op`` on the whole validation set in one run.

        ``batch``, ``feed_extra`` and ``classify_op`` are accepted for interface
        compatibility but not used.

        Returns:
            float32 array holding the loss value.
        """
        self._cache_eval_samples(isolated_graph=True)
        images, labels = self.tmp.cache['valid']
        val_loss = self.session.run(self.ops.eval_loss_op, feed_dict={
            self.ops.x: images, self.ops.label: labels})
        return np.array([val_loss], 'f')

    def eval_ood(self, batch=None, feed_extra=None, classify_op=None):
        """Queue the per-sample max classifier output and labels of the validation set.

        Returns:
            An empty float32 array.
        """
        self._cache_eval_samples(isolated_graph=True)
        classify_op = self.ops.classify_op if classify_op is None else classify_op
        images, labels = self.tmp.cache['valid']

        outputs = self.session.run(classify_op, feed_dict={
            self.ops.x: images, **(feed_extra or {})})
        max_probs = np.amax(outputs, axis=-1)

        # The print queue is drained with tqdm.write, which needs strings.
        self.train_print(str(max_probs))
        self.train_print(str(labels))
        return np.array([], 'f')

    def add_summaries(self, feed_extra=None, **kwargs):
        """Register scalar summaries computed by the eval_* helpers through tf.py_func.

        Covers accuracies, validation loss, macro/weighted F1/precision/recall and,
        when --class_mapping is set, per-class test accuracy.
        """
        del kwargs

        def gen_stats():
            return self.eval_stats(feed_extra=feed_extra)

        def gen_class_stats():
            return self.eval_class_stats(feed_extra=feed_extra)

        def gen_f1_pr():
            return self.eval_f1_pr(feed_extra=feed_extra)

        def gen_val_loss():
            return self.eval_val_loss(feed_extra=feed_extra)

        stats = tf.py_func(gen_stats, [], tf.float32)
        val_loss = tf.py_func(gen_val_loss, [], tf.float32)

        tf.summary.scalar('accuracy_train_labeled/', stats[0])
        tf.summary.scalar('accuracy_valid', stats[1])
        tf.summary.scalar('accuracy_test', stats[2])
        # eval_stats returns exactly three values, so there are no ood_examples/*
        # entries to summarise (indexing stats[3:] would fail at run time).
        tf.summary.scalar('losses/val_loss', val_loss[0])

        mean_class_accuracies = tf.py_func(gen_class_stats, [], tf.float32)
        f1_pr_scores = tf.py_func(gen_f1_pr, [], tf.float32)

        score_tags = ('f1_score/macro', 'precision/macro', 'recall/macro',
                      'f1_score/weighted', 'precision/weighted', 'recall/weighted')
        for index, tag in enumerate(score_tags):
            tf.summary.scalar(tag, f1_pr_scores[index])

        def get_class_mapping():
            """Load the {class index (str): class name} JSON mapping."""
            if FLAGS.class_mapping:
                json_filename = FLAGS.class_mapping
            elif '.' in FLAGS.dataset:
                json_filename = FLAGS.dataset.split(".")[0] + "_class_mappings.json"
            else:
                json_filename = FLAGS.dataset.split("-")[0] + "_class_mappings.json"

            if not os.path.isfile(json_filename):
                raise FileNotFoundError(json_filename + " not found. Please provide another class mapping.")
            with open(json_filename) as f:
                return json.load(f)

        if FLAGS.class_mapping:
            json_data = get_class_mapping()
            for index in range(FLAGS.nclass):
                tf.summary.scalar('accuracy_test/' + json_data[str(index)], mean_class_accuracies[index])

